import warnings

from .base_scheduler import BaseScheduler

class LambdaSL(BaseScheduler):
    """Sets the sparsity level of each parameter group to the final sl
    times a given function. When last_epoch=-1, sets initial sl as zero.
    Args:
        sparsifier (BaseSparsifier): Wrapped sparsifier.
        sl_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in sparsifier.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    Example:
        >>> # Assuming sparsifier has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False):
        self.sparsifier = sparsifier

        if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple):
            self.sl_lambdas = [sl_lambda] * len(sparsifier.module_groups)
        else:
            if len(sl_lambda) != len(sparsifier.module_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(sparsifier.module_groups), len(sl_lambda)))
            self.sl_lambdas = list(sl_lambda)
        super(LambdaSL, self).__init__(sparsifier, last_epoch, verbose)

    def get_sl(self):
        if not self._get_sl_called_within_step:
            warnings.warn(
                "To get the last sparsity level computed by the scheduler, "
                "please use `get_last_sl()`.")
        return [base_sl * lmbda(self.last_epoch)
                for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)]
